{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 为算子添加类型关系"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.ir.attrs import DictAttrs\n",
    "from tvm.relay.op import op as _op\n",
    "    \n",
    "def infer_mod(mod, annotate_spans=True):\n",
    "    if annotate_spans:\n",
    "        mod = relay.transform.AnnotateSpans()(mod)\n",
    "\n",
    "    mod = relay.transform.InferType()(mod)\n",
    "    return mod\n",
    "\n",
    "\n",
    "def infer_expr(expr):\n",
    "    relay.transform.InferTypeLocal(expr)\n",
    "    return expr\n",
    "\n",
    "\n",
    "def assert_has_type(expr, typ, mod=None):\n",
    "    if not mod:\n",
    "        mod = tvm.IRModule({})\n",
    "\n",
    "    mod[\"main\"] = expr\n",
    "    mod = infer_mod(mod)\n",
    "    checked_expr = mod[\"main\"]\n",
    "    checked_type = checked_expr.checked_type\n",
    "    if checked_type != typ:\n",
    "        raise RuntimeError(\"Type mismatch %s vs %s\" % (checked_type, typ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m \u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_type_rel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype_rel_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m   \n",
      "    \u001b[0;32mdef\u001b[0m \u001b[0madd_type_rel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype_rel_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0;34m\"\"\"Attach the type function corresponding to the return type.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        Parameters\u001b[0m\n",
      "\u001b[0;34m        ----------\u001b[0m\n",
      "\u001b[0;34m        rel_name : str\u001b[0m\n",
      "\u001b[0;34m            The type relation name to register.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        type_rel_func : Optional[function (args: List[Type], attrs: Attrs) -> Type]\u001b[0m\n",
      "\u001b[0;34m            The backing relation function which can solve an arbitrary relation on variables.\u001b[0m\n",
      "\u001b[0;34m            Differences with type_rel_func in C++:\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m            1) When type_rel_func is not None\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m               a) OpAddTypeRel on C++ side will adjust type_rel_func with TypeReporter to\u001b[0m\n",
      "\u001b[0;34m                  calling convention of relay type system.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m               b) type_rel_func returns output argument's type, return None means can't\u001b[0m\n",
      "\u001b[0;34m                  infer output's type.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m               c) only support single output operators for now, the last argument is output tensor.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m            2) when type_rel_func is None, will call predefined type_rel_funcs in relay\u001b[0m\n",
      "\u001b[0;34m                   according to ``tvm.relay.type_relation.`` + rel_name.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m        \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m        \u001b[0m_ffi_api\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpAddTypeRel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrel_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype_rel_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/ir/op.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "tvm.ir.Op.add_type_rel??"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{meth}`tvm.ir.Op.add_type_rel` 函数的作用是将返回类型对应的类型函数附加到关系名称上。\n",
    "\n",
    "参数：\n",
    "- `rel_name`: `str`，要注册的类型关系名称。\n",
    "- `type_rel_func`: 可选的函数，接受参数列表和属性作为输入，返回类型。该函数可以解决变量上的任意关系。与 C++ 中的 `type_rel_func` 的区别如下：\n",
    "    1. 当 `type_rel_func` 不为 `None` 时：\n",
    "        - C++ 端的 `OpAddTypeRel` 将使用 `TypeReporter` 调整 `type_rel_func` 以适应 `relay` 类型系统的调用约定。\n",
    "        - `type_rel_func` 返回输出参数的类型，返回 `None` 表示无法推断输出的类型。\n",
    "        - 目前仅支持单个输出的算子，最后一个参数是输出张量。\n",
    "    2. 当 `type_rel_func` 为 `None` 时，将根据 ``tvm.relay.type_relation.`` + `rel_name` 调用预定义的 `relay` 中的 `type_rel_funcs`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义算子类型推断"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_name = \"custom_log\"\n",
    "_op.register(op_name, r\"code(cal log of a tensor.)code\")\n",
    "_op.get(op_name).set_num_inputs(1)\n",
    "_op.get(op_name).add_argument(\"data_0\", \"Tensor\", \"The input data tensor.\")\n",
    "# call default relation functions\n",
    "_op.get(op_name).add_type_rel(\"Identity\")\n",
    "_op.get(op_name).set_support_level(1)\n",
    "_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)\n",
    "_op.register_stateful(op_name, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clog(x):\n",
    "    return relay.Call(_op.get(op_name), [x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tp = relay.TensorType((10, 10), \"float32\")\n",
    "x = relay.var(\"x\", tp)\n",
    "sb = relay.ScopeBuilder()\n",
    "t1 = sb.let(\"t1\", clog(x))\n",
    "t2 = sb.let(\"t2\", relay.add(t1, x))\n",
    "sb.ret(t2)\n",
    "f = relay.Function([x], sb.get())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(10, 10), float32]) {\n",
      "  let %t1 = custom_log(%x);\n",
      "  let %t2 = add(%t1, %x);\n",
      "  %t2\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "fchecked = infer_expr(f)\n",
    "assert fchecked.checked_type == relay.FuncType([tp], tp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(10, 10), float32]) {\n",
      "  let %t1 = custom_log(%x);\n",
      "  let %t2 = add(%t1, %x);\n",
      "  %t2\n",
      "} /* ty=fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */\n"
     ]
    }
   ],
   "source": [
    "print(fchecked)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] {\n",
       "  let <span style=\"color: #AA22FF; font-weight: bold\">%</span>t1: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> custom_log(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  let <span style=\"color: #AA22FF; font-weight: bold\">%</span>t2: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>t1, <span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span>t2\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推断广播 custom_op 的类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_name = \"custom_broadcast_add\"\n",
    "_op.register(op_name, r\"code(Add two tensor with inner broadcasting.)code\")\n",
    "_op.get(op_name).set_num_inputs(2)\n",
    "_op.get(op_name).add_argument(\"data_0\", \"Tensor\", \"The input data tensor.\")\n",
    "_op.get(op_name).add_argument(\"data_1\", \"Tensor\", \"The input data tensor.\")\n",
    "# call default relation functions\n",
    "_op.get(op_name).add_type_rel(\"Broadcast\")\n",
    "_op.get(op_name).set_support_level(1)\n",
    "_op.register_stateful(op_name, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def broadcast_add(x, y):\n",
    "    return relay.Call(_op.get(op_name), [x, y])\n",
    "\n",
    "x = relay.var(\"x\", shape=(10, 4))\n",
    "y = relay.var(\"y\", shape=(5, 10, 1))\n",
    "z = broadcast_add(x, y)\n",
    "func = relay.Function([x, y], z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) {\n",
      "  custom_broadcast_add(%x, %y)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(func)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = relay.TensorType((10, 4), \"float32\")\n",
    "t2 = relay.TensorType((5, 10, 1), \"float32\")\n",
    "t3 = relay.TensorType((5, 10, 4), \"float32\")\n",
    "expected_ty = relay.FuncType([t1, t2], t3)\n",
    "assert_has_type(func, expected_ty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>, <span style=\"color: #AA22FF; font-weight: bold\">%</span>y: Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">1</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">4</span>), float32] {\n",
       "  custom_broadcast_add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x, <span style=\"color: #AA22FF; font-weight: bold\">%</span>y) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">5</span>, <span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">4</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 推断 custom_op 的类型关系"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def custom_log1_rel(arg_types, attrs):\n",
    "    assert len(arg_types) == 1, \"type relation arg number mismatch!\"\n",
    "    if attrs:\n",
    "        assert isinstance(attrs, DictAttrs)\n",
    "    inputa_type = arg_types[0]\n",
    "    return relay.TensorType(inputa_type.shape, inputa_type.dtype)\n",
    "\n",
    "op_name = \"custom_log1\"\n",
    "_op.register(op_name, r\"code(cal log of a tensor.)code\")\n",
    "_op.get(op_name).set_num_inputs(1)\n",
    "_op.get(op_name).add_argument(\"data_0\", \"Tensor\", \"The input data tensor.\")\n",
    "_op.get(op_name).set_attrs_type_key(\"DictAttrs\")\n",
    "# call customized relation functions\n",
    "_op.get(op_name).add_type_rel(\"custom_log1\", custom_log1_rel)\n",
    "_op.get(op_name).set_support_level(1)\n",
    "_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)\n",
    "_op.register_stateful(op_name, False)\n",
    "\n",
    "def clog(x):\n",
    "    return relay.Call(_op.get(op_name), [x])\n",
    "\n",
    "tp = relay.TensorType((10, 10), \"float32\")\n",
    "x = relay.var(\"x\", tp)\n",
    "sb = relay.ScopeBuilder()\n",
    "t1 = sb.let(\"t1\", clog(x))\n",
    "t2 = sb.let(\"t2\", relay.add(t1, x))\n",
    "sb.ret(t2)\n",
    "f = relay.Function([x], sb.get())\n",
    "fchecked = infer_expr(f)\n",
    "assert fchecked.checked_type == relay.FuncType([tp], tp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #AA22FF\">@main</span>(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] {\n",
       "  let <span style=\"color: #AA22FF; font-weight: bold\">%</span>t1: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> custom_log1(<span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  let <span style=\"color: #AA22FF; font-weight: bold\">%</span>t2: Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span> <span style=\"color: #AA22FF; font-weight: bold\">=</span> add(<span style=\"color: #AA22FF; font-weight: bold\">%</span>t1, <span style=\"color: #AA22FF; font-weight: bold\">%</span>x) <span style=\"color: #AA22FF; font-weight: bold\">/*</span> ty<span style=\"color: #AA22FF; font-weight: bold\">=</span>Tensor[(<span style=\"color: #008000\">10</span>, <span style=\"color: #008000\">10</span>), float32] <span style=\"color: #AA22FF; font-weight: bold\">*/</span>;\n",
       "  <span style=\"color: #AA22FF; font-weight: bold\">%</span>t2\n",
       "}\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))\n",
    "mod.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理推断 custom_op 的类型关系的异常事件"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数数量不匹配："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytest\n",
    "def custom_log1_rel(arg_types, attrs):\n",
    "    assert len(arg_types) == 2, \"type relation arg number mismatch!\"\n",
    "    return None\n",
    "\n",
    "op_name = \"custom_log2\"\n",
    "_op.register(op_name, r\"code(cal log of a tensor.)code\")\n",
    "_op.get(op_name).set_num_inputs(1)\n",
    "_op.get(op_name).add_argument(\"data_0\", \"Tensor\", \"The input data tensor.\")\n",
    "_op.get(op_name).set_attrs_type_key(\"DictAttrs\")\n",
    "# call customized relation functions\n",
    "_op.get(op_name).add_type_rel(\"custom_log2\", custom_log1_rel)\n",
    "_op.get(op_name).set_support_level(1)\n",
    "_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)\n",
    "_op.register_stateful(op_name, False)\n",
    "\n",
    "def clog(x):\n",
    "    return relay.Call(_op.get(op_name), [x])\n",
    "\n",
    "tp = relay.TensorType((10, 10), \"float32\")\n",
    "x = relay.var(\"x\", tp)\n",
    "sb = relay.ScopeBuilder()\n",
    "t1 = sb.let(\"t1\", clog(x))\n",
    "t2 = sb.let(\"t2\", relay.add(t1, x))\n",
    "sb.ret(t2)\n",
    "f = relay.Function([x], sb.get())\n",
    "with pytest.raises(AssertionError) as cm:\n",
    "    fchecked = infer_expr(f)\n",
    "    assert \"type relation arg number mismatch\" in str(cm.execption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重复注册："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_name = \"custom_log3\"\n",
    "_op.register(op_name, r\"code(cal log of a tensor.)code\")\n",
    "with pytest.raises(tvm.error.TVMError) as cm:\n",
    "    _op.register(op_name)\n",
    "    assert \"Operator custom_log3 is registered before\" in str(cm.execption)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312x",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
